{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d43c9cdc-530d-40df-b9f0-adf02c00c948",
   "metadata": {},
   "source": [
    "# Lesson 5: Adding Relationships to the SEC Knowledge Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8d5a277-ce22-4199-a3dc-4820a897a9f5",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fd4a6180; padding:15px; margin-left:20px\"> <b>Note:</b> This notebook takes about 30 seconds to be ready to use. Please wait until the \"Kernel starting, please wait...\" message clears from the top of the notebook before running any cells. You may start the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e7f7ac8-5131-4e20-bd93-d727b747bfd8",
   "metadata": {},
   "source": [
    "### Import packages and set up Neo4j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c1eab9e-cdbe-4625-8f71-8d13b760e513",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "import os\n",
    "\n",
    "# Common data processing\n",
    "import textwrap\n",
    "\n",
    "# Langchain\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "from langchain_community.vectorstores import Neo4jVector\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.chains import RetrievalQAWithSourcesChain\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "# Warning control\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48baf0cd-e824-4d23-ab52-3c233eee21e0",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "# Load connection settings from the local .env file. override=True lets values\n",
    "# from .env replace variables that are already set in the process environment.\n",
    "load_dotenv('.env', override=True)\n",
    "NEO4J_URI = os.getenv('NEO4J_URI')\n",
    "NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')\n",
    "NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')\n",
    "# Falls back to the database name 'neo4j' when NEO4J_DATABASE is unset or empty.\n",
    "NEO4J_DATABASE = os.getenv('NEO4J_DATABASE') or 'neo4j'\n",
    "\n",
    "# Global constants: names of the vector index, node label, and node properties\n",
    "# that the vector stores created later in this notebook refer to.\n",
    "VECTOR_INDEX_NAME = 'form_10k_chunks'\n",
    "VECTOR_NODE_LABEL = 'Chunk'\n",
    "VECTOR_SOURCE_PROPERTY = 'text'\n",
    "VECTOR_EMBEDDING_PROPERTY = 'textEmbedding'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8dbc048-131c-422a-aeb7-78014a8e60e0",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "kg = Neo4jGraph(\n",
    "    url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD, database=NEO4J_DATABASE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96614653-85cc-417c-8adc-a3eb9be063c4",
   "metadata": {},
   "source": [
    "### Create a Form 10-K node\n",
    "- Create a node to represent the entire Form 10-K\n",
    "- Populate with metadata taken from a single chunk of the form"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1776cd-f739-4c57-86f1-fd1a306b9a54",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# Take one arbitrary Chunk node and project the form-level properties stored on it.\n",
    "# NOTE(review): presumably every chunk of a form carries the same form metadata - confirm.\n",
    "cypher = \"\"\"\n",
    "  MATCH (anyChunk:Chunk) \n",
    "  WITH anyChunk LIMIT 1\n",
    "  RETURN anyChunk { .names, .source, .formId, .cik, .cusip6 } as formInfo\n",
    "\"\"\"\n",
    "form_info_list = kg.query(cypher)\n",
    "\n",
    "form_info_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb73062-e4e7-44fe-a2f8-6701f2b0ce96",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "form_info = form_info_list[0]['formInfo']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec337f02-c1bc-4ea9-b1ea-775dad0287ba",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "form_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c137e5f0-a97c-4c47-afaf-097da3d78680",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# Create the Form node once per formId. All property assignments are listed in a\n",
    "# single comma-separated ON CREATE SET so that every one of them is conditional on\n",
    "# node creation; a stand-alone SET following ON CREATE SET is a separate clause\n",
    "# that would also run when MERGE matches an already existing node.\n",
    "cypher = \"\"\"\n",
    "    MERGE (f:Form {formId: $formInfoParam.formId })\n",
    "      ON CREATE \n",
    "        SET f.names = $formInfoParam.names,\n",
    "            f.source = $formInfoParam.source,\n",
    "            f.cik = $formInfoParam.cik,\n",
    "            f.cusip6 = $formInfoParam.cusip6\n",
    "\"\"\"\n",
    "\n",
    "# form_info is passed as a single map parameter and read field by field in Cypher.\n",
    "kg.query(cypher, params={'formInfoParam': form_info})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4891b00-ab1a-4c81-8593-612f53f53340",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "kg.query(\"MATCH (f:Form) RETURN count(f) as formCount\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f834ad8a-6f42-4732-9b54-f710785b4131",
   "metadata": {},
   "source": [
    "### Create a linked list of Chunk nodes for each section\n",
    "- Start by identifying chunks from the same section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c40e49c-83ac-4b11-af31-c921f81a7d50",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (from_same_form:Chunk)\n",
    "    WHERE from_same_form.formId = $formIdParam\n",
    "  RETURN from_same_form {.formId, .f10kItem, .chunkId, .chunkSeqId } as chunkInfo\n",
    "    LIMIT 10\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'formIdParam': form_info['formId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b4ccd82-cdb0-4091-855e-4146ffdcb6a0",
   "metadata": {},
   "source": [
    "- Order chunks by their sequence ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4593c853-8546-48ac-a1cc-e30a9498fb68",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (from_same_form:Chunk)\n",
    "    WHERE from_same_form.formId = $formIdParam\n",
    "  RETURN from_same_form {.formId, .f10kItem, .chunkId, .chunkSeqId } as chunkInfo \n",
    "    ORDER BY from_same_form.chunkSeqId ASC\n",
    "    LIMIT 10\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'formIdParam': form_info['formId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc90b9a5-497b-422c-a134-86be5cab1c4f",
   "metadata": {},
   "source": [
    "- Limit chunks to just the \"Item 1\" section, then organize in ascending order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ebe1d0-64dc-4bb4-9f4f-0bd5e3aedcc5",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (from_same_section:Chunk)\n",
    "  WHERE from_same_section.formId = $formIdParam\n",
    "    AND from_same_section.f10kItem = $f10kItemParam // NEW!!!\n",
    "  RETURN from_same_section { .formId, .f10kItem, .chunkId, .chunkSeqId } \n",
    "    ORDER BY from_same_section.chunkSeqId ASC\n",
    "    LIMIT 10\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'formIdParam': form_info['formId'], \n",
    "                         'f10kItemParam': 'item1'})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93dcbbf5-8403-4f12-b663-e9caba70df6f",
   "metadata": {},
   "source": [
    "- Collect ordered chunks into a list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f672e931-000a-4b72-a2e4-c616c03ee190",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (from_same_section:Chunk)\n",
    "  WHERE from_same_section.formId = $formIdParam\n",
    "    AND from_same_section.f10kItem = $f10kItemParam\n",
    "  WITH from_same_section { .formId, .f10kItem, .chunkId, .chunkSeqId } \n",
    "    ORDER BY from_same_section.chunkSeqId ASC\n",
    "    LIMIT 10\n",
    "  RETURN collect(from_same_section) // NEW!!!\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'formIdParam': form_info['formId'], \n",
    "                         'f10kItemParam': 'item1'})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dcc47e6-5e54-4123-8721-37bac41a882f",
   "metadata": {},
   "source": [
    "### Add a NEXT relationship between subsequent chunks\n",
    "- Use the `apoc.nodes.link` procedure from Neo4j's APOC library to link an ordered list of `Chunk` nodes with a `NEXT` relationship\n",
    "- Do this for just the \"Item 1\" section to start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73ad3e35-955f-43d1-a33a-9b6648ac21fa",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "# Collect the chunks of one section in chunkSeqId order and chain them with NEXT\n",
    "# relationships. avoidDuplicates: true is passed so that re-running the cell does\n",
    "# not add a second NEXT relationship between the same pair of chunks.\n",
    "cypher = \"\"\"\n",
    "  MATCH (from_same_section:Chunk)\n",
    "  WHERE from_same_section.formId = $formIdParam\n",
    "    AND from_same_section.f10kItem = $f10kItemParam\n",
    "  WITH from_same_section\n",
    "    ORDER BY from_same_section.chunkSeqId ASC\n",
    "  WITH collect(from_same_section) as section_chunk_list\n",
    "    CALL apoc.nodes.link(\n",
    "        section_chunk_list, \n",
    "        \"NEXT\", \n",
    "        {avoidDuplicates: true}\n",
    "    )  // NEW!!!\n",
    "  RETURN size(section_chunk_list)\n",
    "\"\"\"\n",
    "\n",
    "# Only the 'item1' section is linked here.\n",
    "kg.query(cypher, params={'formIdParam': form_info['formId'], \n",
    "                         'f10kItemParam': 'item1'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "853366b4-d32e-43cb-8ca5-8e2cb74b474b",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "kg.refresh_schema()\n",
    "print(kg.schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6b23340-9a01-4680-824a-c9098db6b72c",
   "metadata": {},
   "source": [
    "- Loop through and create relationships for all sections of the Form 10-K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c54a4ee-32aa-4935-af94-8fdafe75b4b5",
   "metadata": {
    "height": 318
   },
   "outputs": [],
   "source": [
    "# Same linking query as for 'item1', now run once per section so that each\n",
    "# section gets its own NEXT chain (chunks are filtered by f10kItem before linking).\n",
    "cypher = \"\"\"\n",
    "  MATCH (from_same_section:Chunk)\n",
    "  WHERE from_same_section.formId = $formIdParam\n",
    "    AND from_same_section.f10kItem = $f10kItemParam\n",
    "  WITH from_same_section\n",
    "    ORDER BY from_same_section.chunkSeqId ASC\n",
    "  WITH collect(from_same_section) as section_chunk_list\n",
    "    CALL apoc.nodes.link(\n",
    "        section_chunk_list, \n",
    "        \"NEXT\", \n",
    "        {avoidDuplicates: true}\n",
    "    )\n",
    "  RETURN size(section_chunk_list)\n",
    "\"\"\"\n",
    "# The per-section result (the chunk count) is not captured or displayed.\n",
    "for form10kItemName in ['item1', 'item1a', 'item7', 'item7a']:\n",
    "  kg.query(cypher, params={'formIdParam':form_info['formId'], \n",
    "                           'f10kItemParam': form10kItemName})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f05e1db-76d7-4115-8757-6bce05ac4c45",
   "metadata": {},
   "source": [
    "### Connect chunks to their parent form with a PART_OF relationship"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c985b701-045f-4277-a86d-c255f9fae380",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# Pair every Chunk with the Form that has the same formId and MERGE a PART_OF\n",
    "# relationship between them; MERGE reuses an existing relationship on re-runs.\n",
    "cypher = \"\"\"\n",
    "  MATCH (c:Chunk), (f:Form)\n",
    "    WHERE c.formId = f.formId\n",
    "  MERGE (c)-[newRelationship:PART_OF]->(f)\n",
    "  RETURN count(newRelationship)\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3bb02a9-7256-4dcd-9a54-8a8a511993bc",
   "metadata": {},
   "source": [
    "### Create a SECTION relationship on first chunk of each section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38ae7ddf-d73a-4979-b4e0-d1ea5a59ab01",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# Link each Form to the chunks whose chunkSeqId is 0. The section name is stored\n",
    "# as the f10kItem property on the SECTION relationship itself.\n",
    "# NOTE(review): assumes chunkSeqId restarts at 0 in every section - confirm against the loader.\n",
    "cypher = \"\"\"\n",
    "  MATCH (first:Chunk), (f:Form)\n",
    "  WHERE first.formId = f.formId\n",
    "    AND first.chunkSeqId = 0\n",
    "  WITH first, f\n",
    "    MERGE (f)-[r:SECTION {f10kItem: first.f10kItem}]->(first)\n",
    "  RETURN count(r)\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df387da0-1d2b-4eaf-a39c-6ebe3ab70a3f",
   "metadata": {},
   "source": [
    "### Example cypher queries\n",
    "- Return the first chunk of the Item 1 section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a52b8ed-a19a-448c-9994-bd101ecc4142",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "# Follow the SECTION relationship from the Form to the chunk that starts the\n",
    "# requested section, matching the section name on the relationship property.\n",
    "cypher = \"\"\"\n",
    "  MATCH (f:Form)-[r:SECTION]->(first:Chunk)\n",
    "    WHERE f.formId = $formIdParam\n",
    "        AND r.f10kItem = $f10kItemParam\n",
    "  RETURN first.chunkId as chunkId, first.text as text\n",
    "\"\"\"\n",
    "\n",
    "# [0] takes the first result row; this raises IndexError if the query matches nothing.\n",
    "first_chunk_info = kg.query(cypher, params={\n",
    "    'formIdParam': form_info['formId'], \n",
    "    'f10kItemParam': 'item1'\n",
    "})[0]\n",
    "\n",
    "first_chunk_info\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b46c3292-7562-463b-b265-d2d28ec4d151",
   "metadata": {},
   "source": [
    "- Get the second chunk of the Item 1 section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c7240b-8b8f-41c9-861f-007e62490838",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# Step one NEXT relationship forward from the chunk identified by chunkIdParam.\n",
    "cypher = \"\"\"\n",
    "  MATCH (first:Chunk)-[:NEXT]->(nextChunk:Chunk)\n",
    "    WHERE first.chunkId = $chunkIdParam\n",
    "  RETURN nextChunk.chunkId as chunkId, nextChunk.text as text\n",
    "\"\"\"\n",
    "\n",
    "# [0] takes the first result row; this raises IndexError if the chunk has no successor.\n",
    "next_chunk_info = kg.query(cypher, params={\n",
    "    'chunkIdParam': first_chunk_info['chunkId']\n",
    "})[0]\n",
    "\n",
    "next_chunk_info\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f137f111-34b3-48d5-95eb-363292251a5a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(first_chunk_info['chunkId'], next_chunk_info['chunkId'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a82e802-76a1-43fd-9017-29030d2a2299",
   "metadata": {},
   "source": [
    "- Return a window of three chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4385dcb-e912-45e6-9f7d-c64166ebaa79",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "    MATCH (c1:Chunk)-[:NEXT]->(c2:Chunk)-[:NEXT]->(c3:Chunk) \n",
    "        WHERE c2.chunkId = $chunkIdParam\n",
    "    RETURN c1.chunkId, c2.chunkId, c3.chunkId\n",
    "    \"\"\"\n",
    "\n",
    "kg.query(cypher,\n",
    "         params={'chunkIdParam': next_chunk_info['chunkId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ac2448-b9b9-4b2d-b5b8-9ccf78dfa8a0",
   "metadata": {},
   "source": [
    "### Information is stored in the structure of a graph\n",
    "- Matched patterns of nodes and relationships in a graph are called **paths**\n",
    "- The length of a path is equal to the number of relationships in the path\n",
    "- Paths can be captured as variables and used elsewhere in queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5ba5fe2-060e-4de1-ba2e-63ba1e9f83a7",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# Bind the whole three-node pattern to the path variable `window`. Here the\n",
    "# anchor is c1 (the start of the path), so the window begins at the given chunk.\n",
    "# length() counts relationships, so a three-node path has length 2.\n",
    "cypher = \"\"\"\n",
    "    MATCH window = (c1:Chunk)-[:NEXT]->(c2:Chunk)-[:NEXT]->(c3:Chunk) \n",
    "        WHERE c1.chunkId = $chunkIdParam\n",
    "    RETURN length(window) as windowPathLength\n",
    "    \"\"\"\n",
    "\n",
    "kg.query(cypher,\n",
    "         params={'chunkIdParam': next_chunk_info['chunkId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8888d181-ff5d-46c9-b362-47631118015b",
   "metadata": {},
   "source": [
    "### Finding variable length windows\n",
    "- A pattern match will fail if the relationship doesn't exist in the graph\n",
    "- For example, the first chunk in a section has no preceding chunk, so the next query won't return anything"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b205a40-0439-458f-80ba-99511e971cbf",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# The anchor chunk is c2, the middle of the pattern, so a match requires both a\n",
    "# preceding and a following chunk.\n",
    "cypher = \"\"\"\n",
    "    MATCH window=(c1:Chunk)-[:NEXT]->(c2:Chunk)-[:NEXT]->(c3:Chunk) \n",
    "        WHERE c2.chunkId = $chunkIdParam\n",
    "    RETURN nodes(window) as chunkList\n",
    "    \"\"\"\n",
    "# Use the chunk ID of the section's first chunk: it has no incoming NEXT, so nothing matches\n",
    "kg.query(cypher,\n",
    "         params={'chunkIdParam': first_chunk_info['chunkId']})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58d9463d-d6ac-4a41-90bf-fde5ff83062f",
   "metadata": {},
   "source": [
    "- Modify `NEXT` relationship to have variable length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10aa6d79-0891-45dd-bad1-f76587e2f6b8",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# *0..1 makes each NEXT hop optional (zero or one relationship), so the pattern\n",
    "# still matches when the anchor chunk has no predecessor or no successor.\n",
    "# Several windows of different lengths can match the same anchor chunk.\n",
    "cypher = \"\"\"\n",
    "  MATCH window=\n",
    "      (:Chunk)-[:NEXT*0..1]->(c:Chunk)-[:NEXT*0..1]->(:Chunk) \n",
    "    WHERE c.chunkId = $chunkIdParam\n",
    "  RETURN length(window)\n",
    "  \"\"\"\n",
    "\n",
    "kg.query(cypher,\n",
    "         params={'chunkIdParam': first_chunk_info['chunkId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9707a310-05c3-494b-ae5a-946aefcb9a29",
   "metadata": {},
   "source": [
    "- Retrieve only the longest path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9a14f5-683e-49e7-a2b6-70942c29878c",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# Sort the candidate windows by path length and keep only the longest one.\n",
    "# A single anchor chunk is matched here, so LIMIT 1 yields that chunk's longest window.\n",
    "cypher = \"\"\"\n",
    "  MATCH window=\n",
    "      (:Chunk)-[:NEXT*0..1]->(c:Chunk)-[:NEXT*0..1]->(:Chunk)\n",
    "    WHERE c.chunkId = $chunkIdParam\n",
    "  WITH window as longestChunkWindow \n",
    "      ORDER BY length(window) DESC LIMIT 1\n",
    "  RETURN length(longestChunkWindow)\n",
    "  \"\"\"\n",
    "\n",
    "kg.query(cypher,\n",
    "         params={'chunkIdParam': first_chunk_info['chunkId']})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87c9851b-7b82-4f9b-82cb-a61a41307867",
   "metadata": {},
   "source": [
    "### Customize the results of the similarity search using Cypher\n",
    "- Extend the vector store definition to accept a Cypher query\n",
    "- The Cypher query takes the results of the vector similarity search and then modifies them in some way\n",
    "- Start with a simple query that just returns some extra text along with the search results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4881ae-ceb9-4276-aeeb-02a246b2f24b",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "retrieval_query_extra_text = \"\"\"\n",
    "WITH node, score, \"Andreas knows Cypher. \" as extraText\n",
    "RETURN extraText + \"\\n\" + node.text as text,\n",
    "    score,\n",
    "    node {.source} AS metadata\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab40c207-387a-4505-90f1-113ca9cf7336",
   "metadata": {},
   "source": [
    "- Set up the vector store to use the query, then instantiate a retriever and Question-Answer chain in LangChain\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7416b649-4880-410e-a73b-3f8493bfa55d",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "# Vector store over the existing index whose search results are post-processed by\n",
    "# retrieval_query_extra_text. The database name comes from NEO4J_DATABASE so this\n",
    "# store targets the same database as the `kg` connection created above.\n",
    "vector_store_extra_text = Neo4jVector.from_existing_index(\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    text_node_property=VECTOR_SOURCE_PROPERTY,\n",
    "    retrieval_query=retrieval_query_extra_text, # NEW !!!\n",
    ")\n",
    "\n",
    "# Create a retriever from the vector store\n",
    "retriever_extra_text = vector_store_extra_text.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "chain_extra_text = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0),\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=retriever_extra_text\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e83948-d718-4c46-8990-402151093bad",
   "metadata": {},
   "source": [
    "- Ask a question!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98072c4e-ee1d-47e8-9add-8cd6a52fa6d2",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "chain_extra_text(\n",
    "    {\"question\": \"What topics does Andreas know about?\"},\n",
    "    return_only_outputs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3118062-d7ed-4420-830c-75a7e10a403f",
   "metadata": {},
   "source": [
    "- Note, the LLM hallucinates here, using the information in the retrieved text as well as the extra text.\n",
    "- Modify the prompt to try and get a more accurate answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c604b1bb-ab8d-4b53-ab83-1190c49357cd",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "chain_extra_text(\n",
    "    {\"question\": \"What single topic does Andreas know about?\"},\n",
    "    return_only_outputs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b60499f-eee9-4d64-8e96-330bba8db7d2",
   "metadata": {},
   "source": [
    "### Try for yourself!\n",
    "- Modify the query below to add your own additional text\n",
    "- Try engineering the prompt to refine your results\n",
    "- Note, you'll need to reset the vector store, retriever, and chain each time you change the Cypher query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6a2db5-d451-42ba-9327-9bbb8ba22960",
   "metadata": {
    "height": 488
   },
   "outputs": [],
   "source": [
    "# modify the retrieval extra text here then run the entire cell\n",
    "retrieval_query_extra_text = \"\"\"\n",
    "WITH node, score, \"Andreas knows Cypher. \" as extraText\n",
    "RETURN extraText + \"\\n\" + node.text as text,\n",
    "    score,\n",
    "    node {.source} AS metadata\n",
    "\"\"\"\n",
    "\n",
    "# The store, retriever and chain are rebuilt below so they pick up the edited query.\n",
    "# The database name comes from NEO4J_DATABASE so this store targets the same\n",
    "# database as the `kg` connection created above.\n",
    "vector_store_extra_text = Neo4jVector.from_existing_index(\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    text_node_property=VECTOR_SOURCE_PROPERTY,\n",
    "    retrieval_query=retrieval_query_extra_text, # NEW !!!\n",
    ")\n",
    "\n",
    "# Create a retriever from the vector store\n",
    "retriever_extra_text = vector_store_extra_text.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "chain_extra_text = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0),\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=retriever_extra_text\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2bc49fa-4a7e-42e1-8703-bd455a4c0395",
   "metadata": {},
   "source": [
    "### Expand context around a chunk using a window\n",
    "- First, create a regular vector store that retrieves a single node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "476e133e-86f1-45e3-90fb-0fba7bb638c6",
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "# Baseline vector store without a custom retrieval query: each hit is a single\n",
    "# Chunk node. The database is passed explicitly so this store uses the same\n",
    "# NEO4J_DATABASE as the `kg` connection and the other vector stores.\n",
    "neo4j_vector_store = Neo4jVector.from_existing_graph(\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    node_label=VECTOR_NODE_LABEL,\n",
    "    text_node_properties=[VECTOR_SOURCE_PROPERTY],\n",
    "    embedding_node_property=VECTOR_EMBEDDING_PROPERTY,\n",
    ")\n",
    "# Create a retriever from the vector store\n",
    "windowless_retriever = neo4j_vector_store.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "windowless_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0),\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=windowless_retriever\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7002df5-aba8-4ae6-924e-8b8275cfc68d",
   "metadata": {},
   "source": [
    "- Next, define a window retrieval query to get consecutive chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21b72a6-fb6d-4c10-856e-7b2b76384a52",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "# The retrieval query runs once over all (node, score) rows produced by the vector\n",
    "# search, so a bare `ORDER BY ... LIMIT 1` would keep one window for the whole\n",
    "# result set. Sorting first and taking head(collect(window)) grouped by node keeps\n",
    "# the longest window for each hit; the final ORDER BY returns the best score first.\n",
    "retrieval_query_window = \"\"\"\n",
    "MATCH window=\n",
    "    (:Chunk)-[:NEXT*0..1]->(node)-[:NEXT*0..1]->(:Chunk)\n",
    "WITH node, score, window\n",
    "  ORDER BY length(window) DESC\n",
    "WITH node, score, head(collect(window)) as longestWindow\n",
    "WITH nodes(longestWindow) as chunkList, node, score\n",
    "  UNWIND chunkList as chunkRows\n",
    "WITH collect(chunkRows.text) as textList, node, score\n",
    "RETURN apoc.text.join(textList, \" \\n \") as text,\n",
    "    score,\n",
    "    node {.source} AS metadata\n",
    "  ORDER BY score DESC\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bf302d0-ea5a-42c0-bf1c-7d51db79768b",
   "metadata": {},
   "source": [
    "- Set up a QA chain that will use the window retrieval query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a326661-2bd4-47f6-be30-ce9c49d9515d",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "# Vector store over the existing index whose hits are expanded by\n",
    "# retrieval_query_window. The database name comes from NEO4J_DATABASE so this\n",
    "# store targets the same database as the `kg` connection created above.\n",
    "vector_store_window = Neo4jVector.from_existing_index(\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    text_node_property=VECTOR_SOURCE_PROPERTY,\n",
    "    retrieval_query=retrieval_query_window, # NEW!!!\n",
    ")\n",
    "\n",
    "# Create a retriever from the vector store\n",
    "retriever_window = vector_store_window.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "chain_window = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0),\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=retriever_window\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c6d48a5-6a1e-4fff-a739-3bbbaf2879e0",
   "metadata": {},
   "source": [
    "### Compare the two chains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024dcfef-5101-48b3-b1ce-21aa8284f81d",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "question = \"In a single sentence, tell me about Netapp's business.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f5115eb-eba1-4035-a91d-e47b7ed789e3",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "answer = windowless_chain(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")\n",
    "print(textwrap.fill(answer[\"answer\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a1c4774-d131-4636-ae88-70c2ddd618b7",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "answer = chain_window(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")\n",
    "print(textwrap.fill(answer[\"answer\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
